import gym
import numpy as np
import torch
import wandb
import datetime
import os
import os.path as osp
from typing import Tuple

import yaml
import hydra
from hydra.utils import get_original_cwd
import random
from omegaconf import DictConfig, OmegaConf
from transformers import DistilBertTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from diffgro.lisa.expert_dataset import ExpertDataset
from diffgro.lisa.hrl_model import HRLModel
from diffgro.lisa.trainer import Trainer
from diffgro.lisa.eval import get_action

from diffgro.utils import make_dir, print_r, print_y, print_b, save_video
from diffgro.common.evaluations import evaluate, evaluate_complex, evaluate_context
from train import eval_save


def make_env(args):
    """Create a MetaWorld environment and pin its variant space for evaluation.

    Args:
        args: config object exposing ``env_name`` (``"<domain>.<task>"``), ``seed``
            and, for the ``metaworld`` domain, ``goal_resistance``.

    Returns:
        Tuple ``(env, domain_name, task_name)``.

    Raises:
        NotImplementedError: if the domain is neither ``metaworld`` nor
            ``metaworld_complex``.
    """
    print_r(f"<< Making Environment for {args.env_name}... >>")
    domain_name, task_name = args.env_name.split(".")
    if domain_name == "metaworld":
        env = gym.make(task_name, seed=args.seed)
    elif domain_name == "metaworld_complex":
        # Drop the last "-" token (presumably the version suffix, e.g. "v2");
        # the remaining tokens are the sub-tasks, optionally followed by "variant".
        task_list = task_name.split("-")[:-1]
        if task_list[-1] == "variant":
            env = gym.make(
                "complex-variant-v2", seed=args.seed, task_list=task_list[:-1]
            )
        else:
            env = gym.make("complex-v2", seed=args.seed, task_list=task_list)
    else:
        # Fail with the intended error before ``env`` is used; previously an
        # unknown domain crashed with UnboundLocalError on the print below.
        raise NotImplementedError(f"Unsupported domain: {domain_name}")
    print_y(
        f"Obs Space: {env.observation_space.shape}, Act Space: {env.action_space.shape}"
    )

    from diffgro.environments.variant import Categorical, VariantSpace

    # Restrict each variant dimension to a single-element Categorical
    # (presumably a fixed value for evaluation).
    if domain_name == "metaworld":
        print_r(f"Goal Resistance: {args.goal_resistance}")
        env.variant_space.variant_config["goal_resistance"] = Categorical(
            a=[args.goal_resistance]
        )
    else:
        # metaworld_complex: one goal-resistance entry per interactive object.
        env.variant_space.variant_config["goal_resistance"] = VariantSpace(
            {
                "handle": Categorical(a=[0]),
                "button": Categorical(a=[0]),
                "drawer": Categorical(a=[0]),
                "lever": Categorical(a=[0]),
                "door": Categorical(a=[0]),
            }
        )

    env.variant_space.variant_config["arm_speed"] = Categorical(a=[1.0])
    env.variant_space.variant_config["wind_xspeed"] = Categorical(a=[1.0])
    env.variant_space.variant_config["wind_yspeed"] = Categorical(a=[1.0])

    return env, domain_name, task_name


class LISAWrapper:
    """Policy adapter that runs a trained LISA ``HRLModel`` step by step.

    Exposes ``reset()`` (call at the start of every episode) and
    ``predict(obs, deterministic)``, which returns ``(action, None, {})``.
    Between calls it keeps the rolling state/action/option/timestep history
    that ``get_action`` consumes.
    """

    # Sub-task names reported by ``metaworld_complex`` environments, in the same
    # one-hot order as the training data (push_puck, close_drawer, press_button,
    # peg_insert).
    _SKILL_LANG = ("puck", "drawer", "button", "stick")

    def __init__(self, env, model, args, cfg):
        """
        Args:
            env: environment exposing ``action_space``, ``observation_space`` and a
                ``domain_name`` attribute (set by the caller).
            model: trained model exposing ``option_selector.option_dim``,
                ``method``, ``horizon`` and ``lm``.
            args: training config stored in the checkpoint (``trainer.K`` is used
                as the context length).
            cfg: current run config (``trainer.device``).
        """
        self.env = env
        self.model = model

        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

        self.act_dim = env.action_space.shape[-1]
        # Each model input is the observation concatenated with the previous action.
        self.state_dim = env.observation_space.shape[-1] + self.act_dim
        if env.domain_name == "metaworld_complex":
            # A one-hot sub-task indicator is appended to the observation in predict().
            self.state_dim += len(self._SKILL_LANG)
        self.option_dim = model.option_selector.option_dim

        self.K = args.trainer.K
        self.device = cfg.trainer.device

        # The action bounds never change; build the clamp tensors once.
        self._act_low = torch.from_numpy(env.action_space.low).to(self.device)
        self._act_high = torch.from_numpy(env.action_space.high).to(self.device)

    def _to_state(self, obs, prev_action):
        """Return ``[obs, prev_action]`` as a float32 (1, state_dim) tensor on the device."""
        state = np.concatenate([obs, prev_action])
        return (
            torch.from_numpy(state)
            .to(device=self.device, dtype=torch.float32)
            .reshape(1, self.state_dim)
        )

    def predict(self, obs, deterministic):
        """Choose an action for ``obs`` and append it to the episode history.

        Args:
            obs: current observation as a 1-D numpy array.
            deterministic: unused; kept for the policy interface used by the
                evaluation helpers.

        Returns:
            ``(action, None, {})`` where ``action`` is a numpy array clamped to the
            action-space bounds.
        """
        state_dim, act_dim, option_dim = self.state_dim, self.act_dim, self.option_dim
        method = self.model.method
        horizon = self.model.horizon
        device = self.device

        if self.env.domain_name == "metaworld_complex":
            # Append a one-hot encoding of the sub-task currently being executed.
            # NOTE(review): assumes the raw env sits exactly two wrappers deep.
            unwrapped_env = self.env.env.env
            subtask = unwrapped_env.skill_list[unwrapped_env.mode]
            one_hot = np.eye(len(self._SKILL_LANG))[self._SKILL_LANG.index(subtask)]
            obs = np.concatenate([obs, one_hot], axis=0)

        # Both branches build cur_state identically (float32 on ``device``), so
        # get_action always receives the same dtype/device as ``self.states``.
        if self.states is None:
            # First step of the episode: there is no previous action yet.
            self.t = 0
            cur_state = self._to_state(obs, np.zeros(act_dim))
            self.states = cur_state
        else:
            cur_state = self._to_state(obs, self.prev_action)
            self.states = torch.cat([self.states, cur_state], dim=0)
            # reset() already seeded timestep 0, so only later steps are appended.
            self.timesteps = torch.cat(
                [
                    self.timesteps,
                    torch.full((1, 1), self.t, device=device, dtype=torch.long),
                ],
                dim=1,
            )

        # Placeholder slots for this step's action/option; get_action fills them.
        self.actions = torch.cat(
            [self.actions, torch.zeros((1, act_dim), device=device)], dim=0
        )
        self.options = torch.cat(
            [self.options, torch.zeros((1, option_dim), device=device)], dim=0
        )
        action, self.option, self.states, self.actions, self.timesteps, self.options = (
            get_action(
                self.model,
                self.states,
                self.actions,
                self.options,
                self.timesteps,
                self.cls_embeddings,
                self.word_embeddings,
                self.options_list,
                cur_state,
                self.option,
                self.t,
                horizon,
                self.K,
                method,
                state_dim,
                act_dim,
                option_dim,
                device,
            )
        )
        action = torch.clamp(action, self._act_low, self._act_high)
        self.actions[-1] = action
        action = action.detach().cpu().numpy()

        self.t += 1
        self.prev_action = action
        return action, None, {}

    def reset(self):
        """Start a new episode: embed the task instruction and clear the history.

        The instruction is the environment's ``task_name`` split on ``"-"`` with
        the ``"variant"`` and ``"v2"`` tokens removed.
        """
        lang = " ".join(
            [e for e in self.env.task_name.split("-") if e not in ["variant", "v2"]]
        )
        print(lang)

        lm_input = self.tokenizer(
            text=[lang], add_special_tokens=True, return_tensors="pt", padding=True
        ).to(device=self.device)
        with torch.no_grad():
            lm_embeddings = self.model.lm(
                lm_input["input_ids"], lm_input["attention_mask"]
            ).last_hidden_state
            # Position 0 is the [CLS] token; the rest are per-token embeddings.
            self.cls_embeddings = lm_embeddings[:, 0, :]
            self.word_embeddings = lm_embeddings[:, 1:, :]

        self.options_list = []
        self.option = None

        self.actions = torch.zeros(
            (0, self.act_dim), device=self.device, dtype=torch.float32
        )
        self.timesteps = torch.tensor(0, device=self.device, dtype=torch.long).reshape(
            1, 1
        )
        self.options = torch.zeros(
            (0, self.option_dim), device=self.device, dtype=torch.float32
        )
        # None marks "first step" for predict().
        self.states = None


def test(cfg):
    """Evaluate a trained LISA checkpoint on one task or on every task of a domain.

    Rebuilds the ``HRLModel`` from the config stored in ``cfg.checkpoint_path`` and
    evaluates it with ``evaluate`` (``metaworld``) or ``evaluate_complex`` (other
    domains). When several tasks are evaluated, the aggregated successes are saved
    with ``eval_save``.

    Args:
        cfg: Hydra config. Reads ``env.name`` (``"<domain>.<task>"``; the task may be
            ``"all"``), ``train_dataset.expert_location`` (for ``"all"``),
            ``checkpoint_path``, ``eval``, ``trainer.device``,
            ``trainer.num_eval_episodes``, ``iq``, ``goal_resistance``, ``seed``,
            ``context`` and ``video``.
    """
    domain_name, task_name = cfg.env.name.split(".")
    if task_name == "all":
        # Every entry of the dataset folder is treated as one task.
        dataset_folder = cfg.train_dataset.expert_location.split(".")[0]
        task_list = os.listdir(dataset_folder)
    else:
        task_list = [task_name]

    save_path = osp.join("./lisa", domain_name, task_name)
    make_dir(save_path)

    device = cfg.trainer.device
    num_episodes = cfg.trainer.num_eval_episodes

    # Load the training config and dataset statistics stored with the weights.
    # map_location lets a GPU-saved checkpoint be evaluated on a CPU-only host.
    checkpoint = torch.load(cfg.checkpoint_path, map_location=device)
    args = checkpoint["config"]
    max_length = checkpoint["train_dataset_max_length"]
    args.eval = cfg.eval
    args.checkpoint_path = cfg.checkpoint_path

    args.method = args.model.name
    print("=" * 50)
    print(f"Starting evaluation: {args.env.name}")
    # Report the episode count actually used below (from cfg, not the checkpoint).
    print(f"{num_episodes} trajectories")
    print("=" * 50)

    state_dim = args.env.state_dim
    action_dim = args.env.action_dim

    # Create model
    args.option_selector.option_transformer.max_length = int(max_length)
    args.option_selector.option_transformer.max_ep_len = (
        args.env.eval_episode_factor * int(max_length)
    )
    option_selector_args = dict(args.option_selector)
    option_selector_args["state_dim"] = state_dim
    option_selector_args["option_dim"] = args.option_dim
    option_selector_args["codebook_dim"] = args.codebook_dim
    state_reconstructor_args = dict(args.state_reconstructor)
    lang_reconstructor_args = dict(args.lang_reconstructor)
    decision_transformer_args = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "option_dim": args.option_dim,
        "discrete": args.env.discrete,
        "hidden_size": args.dt.hidden_size,
        "use_language": False,
        "use_options": True,
        "option_il": args.dt.option_il,
        "max_length": args.model.K,
        "max_ep_len": args.env.eval_episode_factor * max_length,
        "action_tanh": False,
        "n_layer": args.dt.n_layer,
        "n_head": args.dt.n_head,
        "n_inner": 4 * args.dt.hidden_size,
        "activation_function": args.dt.activation_function,
        "n_positions": args.dt.n_positions,
        "n_ctx": args.dt.n_positions,
        "resid_pdrop": args.dt.dropout,
        "attn_pdrop": args.dt.dropout,
    }
    hrl_model_args = dict(args.model)
    iq_args = cfg.iq
    model = HRLModel(
        option_selector_args,
        state_reconstructor_args,
        lang_reconstructor_args,
        decision_transformer_args,
        iq_args,
        device,
        **hrl_model_args,
    )
    model = model.to(device=device)
    model.load_state_dict(checkpoint["model"])
    # Switch off dropout (resid_pdrop / attn_pdrop) and other training-only
    # behaviour; torch.no_grad() below does not change the module's mode.
    model.eval()

    tot_success = []
    for task in task_list:
        temp_args = OmegaConf.create(
            {
                "env_name": f"{domain_name}.{task}",
                "goal_resistance": cfg.goal_resistance,
                "seed": cfg.seed,
            }
        )
        env, domain_name, _ = make_env(temp_args)
        env.domain_name = domain_name
        env.task_name = task

        contexts = [None]
        if cfg.context:
            context_path = osp.join(
                get_original_cwd(), "config", "contexts", domain_name, f"{task}.yml"
            )
            with open(context_path) as f:
                contexts = yaml.load(f, Loader=yaml.FullLoader)[task]["text"]
                contexts = contexts[: 4 if domain_name == "metaworld" else 2]

        with torch.no_grad():
            for context in contexts:
                if domain_name == "metaworld":
                    success = evaluate(
                        LISAWrapper(env, model, args, cfg),
                        env,
                        domain_name,
                        task,
                        num_episodes,
                        True,
                        cfg.video,
                        save_path,
                        context=context,
                    )
                    # With a context, evaluate() returns a pair; keep the successes.
                    if context is not None:
                        success, _ = success
                else:
                    success = evaluate_complex(
                        LISAWrapper(env, model, args, cfg),
                        env,
                        domain_name,
                        task,
                        num_episodes,
                        True,
                        cfg.video,
                        save_path,
                        context=context,
                    )
                tot_success.extend(success)

    if len(task_list) > 1:
        eval_save(tot_success, save_path)


def _build_expert_dataset(expert_location):
    """Collate per-task trajectory pickles into one dataset pickle at ``expert_location``.

    Trajectories are read from ``<folder>/<task>/trajectory/*``. Here ``<folder>``
    is the part of ``expert_location`` before its first ``"."``. Each pickle must
    provide ``observations``, ``actions``, ``rewards`` and ``terminals``. When the
    folder path contains ``"complex"``, it must also provide ``skill_langs``, which
    is appended to the observations as a one-hot sub-task indicator.

    NOTE(review): uses ``pickle.load``; only run on trusted, locally generated data.
    """
    import pickle

    dataset_folder = expert_location.split(".")[0]
    task_list = os.listdir(dataset_folder)

    dataset = {
        "states": [],
        "actions": [],
        "rewards": [],
        "lengths": [],
        "language": [],
        "dones": [],
    }
    skill_langs = ["push_puck", "close_drawer", "press_button", "peg_insert"]

    for task in task_list:
        # Instruction text: task-name tokens without "variant" / "v2".
        lang = " ".join([e for e in task.split("-") if e not in ["variant", "v2"]])
        file_path = os.path.join(dataset_folder, task, "trajectory")
        for fp in os.listdir(file_path):
            with open(os.path.join(file_path, fp), "rb") as f:
                traj = pickle.load(f)

            states = traj["observations"]
            if "complex" in dataset_folder:
                skill_ids = [skill_langs.index(s) for s in traj["skill_langs"]]
                # Repeat the last label so the count matches the observations
                # (presumably one label per action, one fewer than observations).
                skill_ids = skill_ids + [skill_ids[-1]]
                one_hot = np.eye(len(skill_langs))[skill_ids]
                states = np.concatenate([states, one_hot], axis=1)

            dataset["states"].append(states)
            dataset["actions"].append(traj["actions"])
            dataset["rewards"].append(traj["rewards"])
            dataset["lengths"].append(len(traj["observations"]))
            dataset["language"].append(lang)
            dataset["dones"].append(traj["terminals"])

    print(len(dataset["states"]))
    with open(expert_location, "wb") as f:
        pickle.dump(dataset, f)


def train(args):
    """Train a LISA ``HRLModel`` on expert trajectories and save checkpoints.

    If ``args.train_dataset.expert_location`` does not exist yet, it is first built
    from per-task trajectory pickles by ``_build_expert_dataset``. Checkpoints are
    written under ``args.savepath`` every ``args.save_interval`` iterations and
    once more at the end.

    Args:
        args: Hydra config. Mutated in place: ``method``, ``savepath``, the
            option-transformer lengths, and ``model.horizon`` / ``model.K`` when
            they are ``"max"``. ``warmup_steps`` is also set when resuming.
    """
    device = args.trainer.device

    args.method = args.model.name
    exp_name = (
        f"{args.project_name}-{args.train_dataset.num_trajectories}-{args.method}"
    )
    args.savepath = f'{args.hydra_base_dir}/{args.savedir}/{exp_name}-{datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'

    # Dataset file stem; used in the wandb run name and checkpoint file names.
    dataset_name = args.train_dataset.expert_location.split("/")[-1].split(".")[0]
    if args.wandb:
        wandb.init(project="diffgro", name=f"lisa_lc_{dataset_name}")

    os.makedirs(args.savepath, exist_ok=True)

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

    batch_size = args.batch_size

    train_dataset_args = dict(args.train_dataset)
    if not os.path.exists(train_dataset_args["expert_location"]):
        _build_expert_dataset(train_dataset_args["expert_location"])

    train_dataset = ExpertDataset(**train_dataset_args)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )

    print("=" * 50)
    print(
        f"Starting new experiment: {args.env.name} {args.train_dataset.num_trajectories}"
    )
    print(
        f"{len(train_dataset)} trajectories, {train_dataset.total_timesteps} timesteps found"
    )
    print("=" * 50)

    state_dim = args.env.state_dim
    action_dim = args.env.action_dim

    print(f"--> Train episode max length: {train_dataset.max_length}")
    args.option_selector.option_transformer.max_length = int(train_dataset.max_length)
    args.option_selector.option_transformer.max_ep_len = (
        args.env.eval_episode_factor * int(train_dataset.max_length)
    )

    # "max" means: use the longest training episode.
    if args.model.horizon == "max":
        args.model.horizon = int(train_dataset.max_length)
    if args.model.K == "max":
        args.model.K = int(train_dataset.max_length)

    option_selector_args = dict(args.option_selector)
    option_selector_args["state_dim"] = state_dim
    option_selector_args["option_dim"] = args.option_dim
    option_selector_args["codebook_dim"] = args.codebook_dim
    state_reconstructor_args = dict(args.state_reconstructor)
    lang_reconstructor_args = dict(args.lang_reconstructor)
    # Number of K-step option windows needed to cover the longest episode.
    lang_reconstructor_args["max_options"] = np.ceil(
        train_dataset.max_length / args.model.K
    ).astype(int)
    print(lang_reconstructor_args["max_options"])
    lang_reconstructor_args["option_dim"] = args.option_dim
    decision_transformer_args = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "option_dim": args.option_dim,
        "discrete": args.env.discrete,
        "hidden_size": args.dt.hidden_size,
        "use_language": False,
        "use_options": True,
        "option_il": args.dt.option_il,
        "predict_q": args.use_iq,
        "max_length": args.model.K,
        "max_ep_len": args.env.eval_episode_factor * train_dataset.max_length,
        "n_layer": args.dt.n_layer,
        "n_head": args.dt.n_head,
        "activation_function": args.dt.activation_function,
        "n_positions": args.dt.n_positions,
        "n_ctx": args.dt.n_positions,
        "resid_pdrop": args.dt.dropout,
        "attn_pdrop": args.dt.dropout,
        "no_states": args.dt.no_states,
        "no_actions": args.dt.no_actions,
    }
    hrl_model_args = dict(args.model)
    iq_args = args.iq

    model = HRLModel(
        option_selector_args,
        state_reconstructor_args,
        lang_reconstructor_args,
        decision_transformer_args,
        iq_args,
        device,
        **hrl_model_args,
    )

    start_iter = 1
    if args.resume:
        args.warmup_steps = 0
        checkpoint = torch.load(args.checkpoint_path)
        model.load_state_dict(checkpoint["model"])
        start_iter = checkpoint["iter_num"] + 1
        assert (
            train_dataset.max_length == checkpoint["train_dataset_max_length"]
        ), f"Expected max length of dataset to be {train_dataset.max_length} but got {checkpoint['train_dataset_max_length']}"

    if args.load_options:
        # Load only the option codebook ("option_selector.Z*") from the checkpoint.
        checkpoint = torch.load(args.checkpoint_path)
        checkpoint = checkpoint["model"]
        state_dict = {
            k: v for k, v in checkpoint.items() if k.startswith("option_selector.Z")
        }
        loaded = model.load_state_dict(state_dict, strict=False)
        assert loaded.unexpected_keys == []  ## simple check
        if args.freeze_loaded_options:
            # Distinct loop name: this loop used to shadow the dataset name and
            # corrupt the checkpoint file names saved below.
            for param_name, param in model.named_parameters():
                if param_name.startswith("option_selector.Z"):
                    param.requires_grad = False
            assert (
                not model.option_selector.Z.project_out.bias.requires_grad
            )  ## simple check

    if args.parallel:
        model = torch.nn.DataParallel(model).to(device)
    else:
        model = model.to(device=device)

    # Group parameters on the unwrapped model: DataParallel prefixes every name
    # with "module.", which would otherwise send all parameters to other_params
    # and silently ignore lm_learning_rate / os_learning_rate.
    base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
    params = [(k, v) for k, v in base_model.named_parameters() if v.requires_grad]
    # Different learning rates for the LM part, the option selector and the rest.
    lm_params = {
        "params": [v for k, v in params if k.startswith("lm.")],
        "lr": args.lm_learning_rate,
    }
    os_params = {
        "params": [v for k, v in params if k.startswith("option_selector.")],
        "lr": args.os_learning_rate,
    }
    other_params = {
        "params": [
            v
            for k, v in params
            if not k.startswith("lm.") and not k.startswith("option_selector.")
        ]
    }
    # other_params has no explicit "lr", so it uses args.learning_rate.
    optimizer = torch.optim.AdamW(
        [other_params, lm_params, os_params],
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    def adjust_lr(steps):
        """LR multiplier: linear warmup, then step decay every ``decay_steps``."""
        if steps < args.warmup_steps:
            return min((steps + 1) / args.warmup_steps, 1)
        num_decays = (steps + 1) // args.decay_steps
        return args.lr_decay ** (num_decays)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, adjust_lr)

    trainer_args = dict(args.trainer)

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        optimizer=optimizer,
        train_loader=train_loader,
        scheduler=scheduler,
        skip_words=args.env.skip_words,
        **trainer_args,
    )

    # Defined up front so the final save also works when max_iters is 0.
    iter_num = start_iter - 1
    for iter_num in range(start_iter, start_iter + args.max_iters):
        outputs = trainer.train_iteration(
            iter_num=iter_num, print_logs=True, eval_render=args.video
        )

        if args.wandb and iter_num % args.log_interval == 0:
            wandb.log(outputs, step=iter_num)

        if iter_num % args.save_interval == 0:
            trainer.save(
                iter_num, f"{args.savepath}/model_{dataset_name}_{iter_num}.ckpt", args
            )

    trainer.save(iter_num, f"{args.savepath}/model_{dataset_name}_final.ckpt", args)


def get_args(cfg: DictConfig):
    """Fill in run-time fields on ``cfg`` in place and return the same object.

    ``cfg.trainer.device`` becomes ``"cuda:0"`` when CUDA is available and
    ``"cpu"`` otherwise. ``cfg.hydra_base_dir`` is set to the current working
    directory.
    """
    if torch.cuda.is_available():
        chosen_device = "cuda:0"
    else:
        chosen_device = "cpu"
    cfg.trainer.device = chosen_device
    cfg.hydra_base_dir = os.getcwd()
    return cfg


@hydra.main(config_path="../diffgro/lisa/conf", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: seed the RNGs, then run evaluation or training.

    Evaluation (``test``) runs when ``cfg.eval`` is truthy; otherwise ``train``.
    """
    args = get_args(cfg)

    # Seed Python, NumPy and PyTorch RNGs for reproducibility.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)

    use_deterministic_cudnn = (
        torch.device(args.trainer.device).type == "cuda"
        and torch.cuda.is_available()
        and args.cuda_deterministic
    )
    if use_deterministic_cudnn:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    if not args.eval:
        train(args)
        return
    test(cfg)


# Script entry point; Hydra supplies ``cfg`` through the decorator.
if __name__ == "__main__":
    main()
